04. 练习:TensorFlow 输入

输入

在上一小节中,你向 session 传入一个 tensor 并返回结果。如果你想使用一个非常量(non-constant)该怎么办?这就是 tf.placeholder()feed_dict 派上用场的时候了。这一节将向你讲解向 TensorFlow 传输数据的基础知识。

tf.placeholder()

很遗憾,你不能把数据集赋值给 x 再将它传给 TensorFlow。因为之后你会想要你的 TensorFlow 模型对不同的数据集采用不同的参数。你需要的是 tf.placeholder()

数据经过 tf.session.run() 函数得到的值,由 tf.placeholder() 返回成一个 tensor,这样你可以在 session 运行之前,设置输入。

Session 的 feed_dict

x = tf.placeholder(tf.string)

with tf.Session() as sess:
    output = sess.run(x, feed_dict={x: 'Hello World'})

tf.session.run() 里的 feed_dict 参数设置占位 tensor。上面的例子显示 tensor x 被设置成字符串 "Hello, world"。如下所示,也可以用 feed_dict 设置多个 tensor。

x = tf.placeholder(tf.string)
y = tf.placeholder(tf.int32)
z = tf.placeholder(tf.float32)

with tf.Session() as sess:
    output = sess.run(x, feed_dict={x: 'Test String', y: 123, z: 45.67})

注意:

如果传入 feed_dict 的数据与 tensor 类型不符,就无法被正确处理,你会得到 “ValueError: invalid literal for…”。

练习

让我们看看你对 tf.placeholder()feed_dict 的理解如何。下面的代码有一个报错,但是我想让你修复代码并使其返回数字 123。修改第 11 行,使代码返回数字 123

Start Quiz:

# Solution is available in the other "solution.py" tab
import tensorflow as tf


def run():
    output = None
    x = tf.placeholder(tf.int32)

    with tf.Session() as sess:
        # TODO: Feed the x tensor 123
        output = sess.run(x)

    return output
# Quiz Solution
# Note: You can't run code in this tab
import tensorflow as tf


def run():
    output = None
    x = tf.placeholder(tf.int32)

    with tf.Session() as sess:
        output = sess.run(x, feed_dict={x: 123})

    return output